{
 "cells": [
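  {
   "cell_type": "markdown",
   "id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9",
   "metadata": {},
   "source": [
    "# Unit Test Generator\n",
    "\n",
    "Generate unit tests for a snippet of Python code with a selectable LLM (GPT, Claude, Gemini, or a local Llama model via Ollama) through a small Gradio UI."
   ]
  },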
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "from google import genai\n",
    "from google.genai import types\n",
    "import anthropic\n",
    "import ollama\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f672e1c-87e9-4865-b760-370fa605e614",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "# Write each provider key back into the process environment, substituting a\n",
    "# placeholder string for any key that is not set.\n",
    "os.environ.update({\n",
    "    key_name: os.getenv(key_name, 'your-key-if-not-using-env')\n",
    "    for key_name in ('OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'GOOGLE_API_KEY')\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize\n",
    "\n",
    "# Model identifiers, one per provider\n",
    "OPENAI_MODEL = 'gpt-4o'\n",
    "CLAUDE_MODEL = 'claude-sonnet-4-20250514'\n",
    "GEMINI_MODEL = 'gemini-2.5-flash'\n",
    "LLAMA_MODEL = 'llama3.2'\n",
    "\n",
    "# Provider clients, each constructed with no explicit arguments\n",
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()\n",
    "client = genai.Client()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6896636f-923e-4a2c-9d6c-fac07828a201",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt shared by all four providers: it is passed either through\n",
    "# messages_for_unit_test or directly as the provider's system instruction.\n",
    "# It asks for bare Python code only, with no Markdown fences or language tags.\n",
    "system_message = \"\"\"\n",
    "You are an effective programming assistant specialized to generate Python code based on the inputs.\n",
    "Respond only with Python code; use comments sparingly and do not provide any explanation other than occasional comments.\n",
    "Do not include Markdown formatting (```), language tags (python), or extra text \\n.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_prompt_for_unit_test(python):\n",
    "    \"\"\"Build the user prompt asking a model to write unit tests for `python`.\n",
    "\n",
    "    Args:\n",
    "        python: Source code, as a string, to generate unit tests for.\n",
    "\n",
    "    Returns:\n",
    "        The prompt text: the code inserted verbatim, followed by the request\n",
    "        and the response-format rule.\n",
    "    \"\"\"\n",
    "    # The code goes in flush-left and untouched. Python is indentation-\n",
    "    # sensitive, so indenting only its first line (as an indented template\n",
    "    # would) hands the model source that no longer matches what was pasted.\n",
    "    return (\n",
    "        \"Consider the following Python code:\\n\\n\"\n",
    "        f\"{python}\\n\\n\"\n",
    "        \"Generate unit tests for this code and return them together with the Python code.\\n\"\n",
    "        \"Response rule: in your response do not include Markdown formatting (```), \"\n",
    "        \"language tags (python), or extra text.\\n\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for_unit_test(python):\n",
    "    \"\"\"Assemble the chat-format message list for a unit-test request.\n",
    "\n",
    "    Returns a two-item list: the shared system prompt first, then the user\n",
    "    prompt built from `python`.\n",
    "    \"\"\"\n",
    "    system_turn = {\"role\": \"system\", \"content\": system_message}\n",
    "    user_turn = {\"role\": \"user\", \"content\": user_prompt_for_unit_test(python)}\n",
    "    return [system_turn, user_turn]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b497b3-f569-420e-b92e-fb0f49957ce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample Python source, held as a string, that pre-fills the code textbox in\n",
    "# the UI cell. It defines a linear congruential generator plus two functions\n",
    "# that compute maximum subarray sums over the generated numbers.\n",
    "python_hard = \"\"\"\n",
    "\n",
    "def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
    "    value = seed\n",
    "    while True:\n",
    "        value = (a * value + c) % m\n",
    "        yield value\n",
    "\n",
    "def max_subarray_sum(n, seed, min_val, max_val):\n",
    "    lcg_gen = lcg(seed)\n",
    "    random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
    "    max_sum = float('-inf')\n",
    "    for i in range(n):\n",
    "        current_sum = 0\n",
    "        for j in range(i, n):\n",
    "            current_sum += random_numbers[j]\n",
    "            if current_sum > max_sum:\n",
    "                max_sum = current_sum\n",
    "    return max_sum\n",
    "\n",
    "def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
    "    total_sum = 0\n",
    "    lcg_gen = lcg(initial_seed)\n",
    "    for _ in range(20):\n",
    "        seed = next(lcg_gen)\n",
    "        total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
    "    return total_sum\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be9f47d-5213-4700-b0e2-d444c7c738c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_gpt(python):\n",
    "    \"\"\"Stream a unit-test reply from the OpenAI chat completions API.\n",
    "\n",
    "    Args:\n",
    "        python: Source code to generate unit tests for.\n",
    "\n",
    "    Yields:\n",
    "        The reply accumulated so far, once per chunk that carries a choice.\n",
    "    \"\"\"\n",
    "    stream = openai.chat.completions.create(\n",
    "        model=OPENAI_MODEL,\n",
    "        messages=messages_for_unit_test(python),\n",
    "        stream=True,\n",
    "    )\n",
    "    reply = \"\"\n",
    "    for chunk in stream:\n",
    "        # A streamed chunk can arrive with an empty `choices` list (for\n",
    "        # example a usage-only chunk); indexing [0] on it raises IndexError.\n",
    "        if not chunk.choices:\n",
    "            continue\n",
    "        reply += chunk.choices[0].delta.content or \"\"\n",
    "        yield reply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8669f56b-8314-4582-a167-78842caea131",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_claude(python):\n",
    "    \"\"\"Stream a unit-test reply from Claude.\n",
    "\n",
    "    Yields the text accumulated so far each time a new piece of text arrives.\n",
    "    \"\"\"\n",
    "    user_turn = {\"role\": \"user\", \"content\": user_prompt_for_unit_test(python)}\n",
    "    request = claude.messages.stream(\n",
    "        model=CLAUDE_MODEL,\n",
    "        max_tokens=2000,\n",
    "        system=system_message,\n",
    "        messages=[user_turn],\n",
    "    )\n",
    "    accumulated = \"\"\n",
    "    with request as stream:\n",
    "        for piece in stream.text_stream:\n",
    "            accumulated = accumulated + piece\n",
    "            yield accumulated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97205162",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_gemini(python):\n",
    "    \"\"\"Stream a unit-test reply from Gemini.\n",
    "\n",
    "    Yields the text accumulated so far after every streamed chunk; a chunk\n",
    "    without text leaves the accumulated reply unchanged.\n",
    "    \"\"\"\n",
    "    generation_config = types.GenerateContentConfig(system_instruction=system_message)\n",
    "    chunks = client.models.generate_content_stream(\n",
    "        model=GEMINI_MODEL,\n",
    "        config=generation_config,\n",
    "        contents=user_prompt_for_unit_test(python),\n",
    "    )\n",
    "\n",
    "    accumulated = \"\"\n",
    "    for chunk in chunks:\n",
    "        piece = chunk.text\n",
    "        if piece:\n",
    "            accumulated += piece\n",
    "        yield accumulated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f94b13e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_llama_local(python):\n",
    "    \"\"\"Stream a unit-test reply from the local Llama model through Ollama.\n",
    "\n",
    "    Args:\n",
    "        python: Source code to generate unit tests for.\n",
    "\n",
    "    Yields:\n",
    "        The reply accumulated so far, once per chunk whose message carries a\n",
    "        'content' entry.\n",
    "    \"\"\"\n",
    "    stream = ollama.chat(\n",
    "        # Use the shared constant, as the other providers do, so the model\n",
    "        # name is changed in one place only.\n",
    "        model=LLAMA_MODEL,\n",
    "        messages=messages_for_unit_test(python),\n",
    "        stream=True,\n",
    "    )\n",
    "\n",
    "    reply = \"\"\n",
    "    for chunk in stream:\n",
    "        message = chunk['message']\n",
    "        if 'content' in message:\n",
    "            reply += message['content']\n",
    "            yield reply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_unit_test(python, model):\n",
    "    \"\"\"Stream a unit-test reply for `python` from the provider named by `model`.\n",
    "\n",
    "    `model` is one of \"GPT\", \"Claude\", \"Gemini\" or \"Llama\"; any other value\n",
    "    raises ValueError once the generator is first advanced. Every value\n",
    "    produced by the chosen provider's stream is yielded unchanged.\n",
    "    \"\"\"\n",
    "    if model == \"GPT\":\n",
    "        yield from stream_gpt(python)\n",
    "    elif model == \"Claude\":\n",
    "        yield from stream_claude(python)\n",
    "    elif model == \"Gemini\":\n",
    "        yield from stream_gemini(python)\n",
    "    elif model == \"Llama\":\n",
    "        yield from stream_llama_local(python)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradio UI: a row with the code input (pre-filled with python_hard) next to\n",
    "# the unit-test output, then a row with the model selector and the button.\n",
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard)\n",
    "        unit_test = gr.Textbox(label=\"Unit test\", lines=10)\n",
    "    with gr.Row():\n",
    "        model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\", \"Llama\"], label=\"Select model\", value=\"GPT\")\n",
    "        generate_ut = gr.Button(\"Generate Unit tests\")\n",
    "\n",
    "    # The dropdown choices must match the model names generate_unit_test\n",
    "    # compares against; its yielded values are routed to the unit_test box.\n",
    "    generate_ut.click(generate_unit_test, inputs=[python, model], outputs=[unit_test])\n",
    "\n",
    "# NOTE(review): inbrowser=True presumably opens the app in a browser tab on\n",
    "# launch - confirm against the Gradio documentation.\n",
    "ui.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
